{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Save and Load the Model\n",
        "\n",
        "In this section we will look at how to persist model state by saving and loading it, and how to run predictions with the restored model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision.models as models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Saving and Loading Model Weights\n",
        "PyTorch models store their learned parameters in an internal\n",
        "state dictionary called ``state_dict``. It can be persisted with the ``torch.save``\n",
        "function:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = models.vgg16(pretrained=True)\n",
        "torch.save(model.state_dict(), 'model_weights.pth')"
      ]
    },
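    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A ``state_dict`` is an ordered mapping from parameter names to tensors. As a minimal sketch, the cell below prints the entries for a small ``nn.Linear`` layer; the layer and the name ``tiny`` are illustrative assumptions, chosen only because VGG16 has too many entries to read comfortably.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "# Hypothetical small layer, used only for illustration\n",
        "tiny = nn.Linear(4, 2)\n",
        "\n",
        "# Each entry maps a parameter name to its tensor\n",
        "for name, tensor in tiny.state_dict().items():\n",
        "    print(name, tuple(tensor.shape))"
      ]
    },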
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To load model weights, you need to create an instance of the same model first, and then load the parameters\n",
        "using the ``load_state_dict()`` method.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# We do not specify pretrained=True, i.e. we do not load the default weights\n",
        "model = models.vgg16()\n",
        "model.load_state_dict(torch.load('model_weights.pth'))\n",
        "model.eval()"
      ]
    },
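    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to check that a save/load round trip preserves the weights is to compare the tensors before and after. The cell below is an illustrative sketch only: it assumes two small ``nn.Linear`` layers (named ``source`` and ``target`` here) and an in-memory ``io.BytesIO`` buffer in place of VGG16 and a file on disk.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import io\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "# Hypothetical small layers standing in for the saved model and a fresh instance\n",
        "source = nn.Linear(4, 2)\n",
        "target = nn.Linear(4, 2)\n",
        "\n",
        "# Save the weights to an in-memory buffer instead of a file on disk\n",
        "buffer = io.BytesIO()\n",
        "torch.save(source.state_dict(), buffer)\n",
        "buffer.seek(0)\n",
        "\n",
        "# Load the weights into the fresh instance and compare every tensor\n",
        "target.load_state_dict(torch.load(buffer))\n",
        "source_state = source.state_dict()\n",
        "target_state = target.state_dict()\n",
        "print(all(torch.equal(source_state[key], target_state[key]) for key in source_state))"
      ]
    },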
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>Be sure to call the ``model.eval()`` method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do so will yield inconsistent inference results.</p></div>\n",
        "\n"
      ]
    },
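    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why evaluation mode matters, the sketch below applies a standalone ``nn.Dropout`` layer to the same input in both modes. The layer ``drop`` and the input ``x`` are illustrative assumptions and are not part of the VGG16 example.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "# Hypothetical standalone dropout layer and input, for illustration only\n",
        "drop = nn.Dropout(p=0.5)\n",
        "x = torch.ones(8)\n",
        "\n",
        "# Training mode: elements are zeroed at random and the rest are scaled by 1 / (1 - p)\n",
        "drop.train()\n",
        "print(drop(x))\n",
        "\n",
        "# Evaluation mode: dropout is a no-op, so the output equals the input\n",
        "drop.eval()\n",
        "print(drop(x))"
      ]
    },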
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Saving and Loading Models with Shapes\n",
        "When loading model weights, we needed to instantiate the model class first, because the class\n",
        "defines the structure of a network. We might want to save the structure of this class together with\n",
        "the model, in which case we can pass ``model`` (and not ``model.state_dict()``) to the saving function:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "torch.save(model, 'model.pth')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can then load the model like this:\n",
        "\n"
      ]
    },
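    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Recent PyTorch releases default to weights_only=True, which refuses to unpickle a full model object,\n",
        "# so pass weights_only=False here. Only do this for files you trust.\n",
        "model = torch.load('model.pth', weights_only=False)\n",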
        "print(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>This approach uses the Python [pickle](https://docs.python.org/3/library/pickle.html) module to serialize the model, so the actual class definition must be available when the model is loaded.</p></div>\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Related Tutorials\n",
        "[Saving and Loading a General Checkpoint in PyTorch](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html)\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.9"
    },
    "vscode": {
      "interpreter": {
        "hash": "cf18841ace8313d0bc088ca146c17a6c0040e82121d5cb75c0ea07172309253d"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
